from scipy import io
import matplotlib.pyplot as plt
import pickle as pkl
import numpy as np
from pygenstability.pygenstability import _evaluate_VI

if __name__ == '__main__':
    # Compare the partitions stored in the MATLAB results file against the
    # "Neuron Class" labels, and plot the comparison score next to the
    # stored stability curve as a function of scale.

    # NOTE(review): pickle can execute arbitrary code while loading; only run
    # this on a trusted copy of the data file.
    with open("hox_gene_expression.pkl", "rb") as pkl_file:
        # The trailing entry is dropped by the [:-1] slice
        # (presumably the last row of a pandas DataFrame -- TODO confirm).
        data = pkl.load(pkl_file)[:-1]

    # Encode each class label as an integer: return_inverse gives, for every
    # sample, the index of its label in the sorted array of unique labels.
    ground_truth_2 = np.unique(data["Neuron Class"].to_list(), return_inverse=True)[1]

    mat_data = io.loadmat('data_matlab.mat')
    # loadmat returns 2-D arrays; take the first row to get 1-D vectors.
    scales = mat_data['tg'][0]
    stab = mat_data['stab'][0]
    # Transposed so that comms[i] is indexed like scales[i]
    # (assumes COM_RV stores one partition per column -- TODO confirm).
    comms = mat_data['COM_RV'].T

    # Score between the partition at each scale and the ground truth: the
    # (0, 1) tuple selects the two entries of the list passed alongside it
    # (presumably variation of information, per the helper's name).
    vis = [
        _evaluate_VI((0, 1), [comms[scale_idx], ground_truth_2])
        for scale_idx in range(len(scales))
    ]

    plt.figure(figsize=(5, 3))
    # Matplotlib keyword names are case-sensitive: the colour alias is
    # lowercase ``c``; an uppercase ``C`` is rejected as an unknown property.
    plt.semilogx(scales, vis, c='C0', label='VI')
    plt.semilogx(scales, stab, c='C1', label='stability')
    plt.axhline(1, ls='--', c='k')
    # Mark a hand-picked scale index and the scale where the score is lowest.
    _id = 19
    plt.axvline(scales[_id])
    plt.axvline(scales[np.argmin(vis)])
    plt.gca().set_ylim(0, 1.05)
    plt.legend()
    plt.xlabel('scale')
    plt.savefig('multiscale_matlab.pdf', bbox_inches='tight')
